/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "trans_tensordesc.h"

#include <cstdint>
#include <limits>
#include <list>
#include <string>
#include <vector>
#include "securec.h"

#include "infra/base/assertion.h"
#include "infra/math/fp16_t.h"

#include "framework/common/types.h"
#include "framework/infra/log/log.h"
#include "framework/graph/utils/tensor_utils.h"

#include "common/math/math_util.h"
#include "cast_check_utils.h"


using namespace std;
using namespace ge;

namespace hiai {
/**
 * @brief Fill a ccTensor_t descriptor from a ge::TensorDesc.
 *
 * Validates the format and data-type ranges, stores the real dimension count in
 * ccTensor.realDimCnt, then dispatches on the format:
 *  - FORMAT_ND: the first real_dim_cnt dims are passed to SetTensorNdDescriptor.
 *  - FORMAT_NDHWC / FORMAT_NCDHW: the dims produced by TransferDim go to SetNDTensorDescriptor.
 *  - FORMAT_NHWC: the TransferDim dims are reordered from (N, H, W, C) to (N, C, H, W)
 *    for SetTensor4dDescriptor.
 *  - any other format: the first four TransferDim dims go to SetTensor4dDescriptor unchanged.
 * A shape of exactly {0} returns SUCCESS without calling any descriptor setter
 * (only realDimCnt has been written at that point).
 *
 * @param tensor   source tensor description
 * @param ccTensor descriptor to populate
 * @return SUCCESS on success, FAILED on any validation or descriptor-setup error
 */
HCS_API_EXPORT Status InitTensorDescriptor(const ge::TensorDesc& tensor, ccTensor_t& ccTensor)
{
    ge::Format format = tensor.GetFormat();
    int32_t originDataType = static_cast<int32_t>(tensor.GetDataType());
    std::vector<int64_t> dim = tensor.GetShape().GetDims();

    // format
    HIAI_EXPECT_IN_RANGE_R(format, FORMAT_NCHW, FORMAT_RESERVED - 1, FAILED);
    // data type
    HIAI_EXPECT_IN_RANGE_R(originDataType, CC_DATA_FLOAT, CC_DATA_RESERVED - 1, FAILED);

    uint32_t real_dim_cnt = 0;
    (void)ge::TensorUtils::GetRealDimCnt(tensor, real_dim_cnt);
    // realDim below is a fixed array of CC_DIM_MAX entries, so larger counts are rejected.
    if (real_dim_cnt > CC_DIM_MAX) {
        FMK_LOGE("param is invalid, real_dim_cnt:%u", real_dim_cnt);
        return FAILED;
    }
    ccTensor.realDimCnt = static_cast<int32_t>(real_dim_cnt);

    DataType_t dataType = tagDataType(originDataType);
    if (format == FORMAT_ND) {
        int32_t realDim[CC_DIM_MAX] = {0};
        uint32_t i = 0;
        for (auto dim_temp : dim) {
            if (i >= real_dim_cnt) {
                break;
            }
            // Dims are int64_t but the ND descriptor takes int32_t; reject values that would
            // otherwise be silently truncated into a wrong shape.
            if (dim_temp < static_cast<int64_t>(std::numeric_limits<int32_t>::min()) ||
                dim_temp > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
                FMK_LOGE("dim[%u] is out of int32 range: %lld", i, static_cast<long long>(dim_temp));
                return FAILED;
            }
            realDim[i] = static_cast<int32_t>(dim_temp);
            i++;
        }
        auto ccRet = SetTensorNdDescriptor(ccTensor, dataType, real_dim_cnt, realDim);
        if (ccRet != SUCCESS) {
            FMK_LOGE("Call SetTensorNdDescriptor failed. ccRet = %d", ccRet);
            return FAILED;
        }
        return SUCCESS;
    }

    std::vector<int32_t> dimVector;
    if (TransferDim(dim, dimVector) != SUCCESS) {
        FMK_LOGE("TransferDim failed.");
        return FAILED;
    }

    // A shape of exactly {0} is accepted without configuring a descriptor.
    if (dim.size() == 1 && dim[0] == 0) {
        return SUCCESS;
    }

    Status ret = SUCCESS;
    if (format == FORMAT_NDHWC || format == FORMAT_NCDHW) {
        ret = SetNDTensorDescriptor(ccTensor, format, dataType, dimVector);
    } else {
        // The 4D setters index four dims; report a short vector as a failure instead of
        // letting std::vector::at throw std::out_of_range.
        constexpr size_t kDim4dCount = 4;
        if (dimVector.size() < kDim4dCount) {
            FMK_LOGE("dim size %zu is less than 4, format:%d", dimVector.size(), static_cast<int>(format));
            return FAILED;
        }
        std::vector<int32_t> dims;
        if (format == FORMAT_NHWC) {
            // NHWC -> NCHW ordering expected by SetTensor4dDescriptor.
            dims = {dimVector[0], dimVector[3], dimVector[1], dimVector[2]};
        } else {
            dims = {dimVector[0], dimVector[1], dimVector[2], dimVector[3]};
        }
        ret = SetTensor4dDescriptor(ccTensor, format, dataType, dims);
    }
    if (ret != SUCCESS) {
        FMK_LOGE("SetTensorDescriptor failed. ret:%d, format:%d", ret, static_cast<int>(format));
        return FAILED;
    }

    return SUCCESS;
}

/**
 * @brief Transform the data in `input` (described by xDesc) into `output` (described by yDesc).
 *
 * Steps: check the descriptor pair with CastCheckUtils::CheckInputOutputTensor, convert both
 * ge descriptors to ccTensor_t with InitTensorDescriptor, query the output byte size with
 * GetTensorMemorySizeInBytes, then delegate to the ccTensor_t overload of TransTensor.
 *
 * @param xDesc  description of the input tensor
 * @param input  buffer holding the input data
 * @param yDesc  description of the output tensor
 * @param output buffer receiving the transformed data
 * @return SUCCESS on success; FAILED if any step fails or the output size is zero
 */
HCS_API_EXPORT Status TransTensor(const ge::TensorDesc& xDesc, const ge::BaseBuffer& input,
    const ge::TensorDesc& yDesc, ge::BaseBuffer& output)
{
    HIAI_EXPECT_TRUE_R(CastCheckUtils::CheckInputOutputTensor(xDesc, yDesc), FAILED);
    ccTensor_t xDescCC = {};
    ccTensor_t yDescCC = {};
    uint32_t ySizeInBytes = 0;

    Status ret = InitTensorDescriptor(xDesc, xDescCC);
    if (ret != SUCCESS) {
        FMK_LOGE("get input ccTensor descriptor failed.");
        return FAILED;
    }

    ret = InitTensorDescriptor(yDesc, yDescCC);
    if (ret != SUCCESS) {
        FMK_LOGE("call Init out_desc TensorDescriptor failed.");
        return FAILED;
    }

    // Get the output data size. A zero size is treated as a failure even when the call
    // itself succeeds, so the size is logged too; otherwise the message reads "ret = 0".
    ret = GetTensorMemorySizeInBytes(yDescCC, ySizeInBytes);
    if (ret != SUCCESS || ySizeInBytes == 0) {
        FMK_LOGE("GetTensorMemorySizeInBytes failed. ret = %d, ySizeInBytes = %u", ret, ySizeInBytes);
        return FAILED;
    }

    // NOTE(review): ySizeInBytes was already checked non-zero above; the re-check is kept in
    // case the callee can update it — confirm against the overload's signature.
    ret = TransTensor(xDescCC, input, yDescCC, output, ySizeInBytes);
    if (ret != SUCCESS || ySizeInBytes == 0) {
        FMK_LOGE("TransTensor failed. ret = %d, ySizeInBytes = %u", ret, ySizeInBytes);
        return FAILED;
    }
    return SUCCESS;
}
} // namespace hiai